{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import argparse\n",
        "from pathlib import Path\n",
        "import pickle\n",
        "import mlflow\n",
        "\n",
        "import os \n",
        "import json"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "jupyter": {
          "outputs_hidden": false,
          "source_hidden": false
        },
        "nteract": {
          "transient": {
            "deleting": false
          }
        }
      },
      "outputs": [],
      "source": [
        "# Define Arguments for this step\n",
        "\n",
        "class MyArgs:\n",
        "    def __init__(self, /, **kwargs):\n",
        "        self.__dict__.update(kwargs)\n",
        "\n",
        "args = MyArgs(\n",
        "    # name under which the model is logged and registered\n",
        "    model_name=\"taxi-model\",\n",
        "    # location the trained model is loaded from\n",
        "    model_path=\"/tmp/train\",\n",
        "    # directory expected to contain the \"deploy_flag\" file\n",
        "    evaluation_output=\"/tmp/evaluate\",\n",
        "    # directory that receives model_info.json\n",
        "    model_info_output_path=\"/tmp/model_info_output_path\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "jupyter": {
          "outputs_hidden": false,
          "source_hidden": false
        },
        "nteract": {
          "transient": {
            "deleting": false
          }
        }
      },
      "outputs": [],
      "source": [
        "def main(args):\n",
        "    '''Loads model, registers it if deply flag is True'''\n",
        "    \n",
        "    with open((Path(args.evaluation_output) / \"deploy_flag\"), 'rb') as infile:\n",
        "        deploy_flag = int(infile.read())\n",
        "\n",
        "    mlflow.log_metric(\"deploy flag\", int(deploy_flag))\n",
        "    \n",
        "    if deploy_flag==1:\n",
        "\n",
        "        print(\"Registering \", args.model_name)\n",
        "\n",
        "        # load model\n",
        "        model =  mlflow.sklearn.load_model(args.model_path) \n",
        "\n",
        "        # log model using mlflow\n",
        "        mlflow.sklearn.log_model(model, args.model_name)\n",
        "\n",
        "        # register logged model using mlflow\n",
        "        run_id = mlflow.active_run().info.run_id\n",
        "        model_uri = f'runs:/{run_id}/{args.model_name}'\n",
        "        mlflow_model = mlflow.register_model(model_uri, args.model_name)\n",
        "        model_version = mlflow_model.version\n",
        "\n",
        "        # write model info\n",
        "        print(\"Writing JSON\")\n",
        "        dict = {\"id\": \"{0}:{1}\".format(args.model_name, model_version)}\n",
        "        output_path = os.path.join(args.model_info_output_path, \"model_info.json\")\n",
        "        with open(output_path, \"w\") as of:\n",
        "            json.dump(dict, fp=of)\n",
        "\n",
        "    else:\n",
        "        print(\"Model will not be registered!\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "jupyter": {
          "outputs_hidden": false,
          "source_hidden": false
        },
        "nteract": {
          "transient": {
            "deleting": false
          }
        }
      },
      "outputs": [],
      "source": [
        "# Run the step inside a context manager so the MLflow run is always closed,\n",
        "# even if main() raises; otherwise the run stays active and re-running this\n",
        "# cell fails because a run is already in progress.\n",
        "with mlflow.start_run():\n",
        "    lines = [\n",
        "        f\"Model name: {args.model_name}\",\n",
        "        f\"Model path: {args.model_path}\",\n",
        "        f\"Evaluation output path: {args.evaluation_output}\",\n",
        "    ]\n",
        "\n",
        "    for line in lines:\n",
        "        print(line)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernel_info": {
      "name": "local-env"
    },
    "kernelspec": {
      "display_name": "Python 3.9.6 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    },
    "nteract": {
      "version": "nteract-front-end@1.0.0"
    },
    "vscode": {
      "interpreter": {
        "hash": "c87d6401964827bd736fe8e727109b953dd698457ca58fb5acabab22fd6dac41"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
